a84fff
@@ -909,8 +909,8 @@
private static SharedResult extractSharedOptimizationInfo(ParseContext pctx,
       }
     }
 
-    discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache, discardableInputOps));
-    discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache, discardableOps));
+    discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache,
+        Sets.union(discardableInputOps, discardableOps)));
     discardableInputOps.addAll(gatherDPPBranchOps(pctx, optimizerCache, retainableOps,
         discardableInputOps));
     return new SharedResult(retainableOps, discardableOps, discardableInputOps,
@@ -947,11 +947,7 @@
private static SharedResult extractSharedOptimizationInfo(ParseContext pctx,
             .get((TableScanOperator) op);
         for (Operator<?> dppSource : c) {
           // Remove the branches
-          Operator<?> currentOp = dppSource;
-          while (currentOp.getNumChild() <= 1) {
-            dppBranches.add(currentOp);
-            currentOp = currentOp.getParentOperators().get(0);
-          }
+          removeBranch(dppSource, dppBranches, ops);
         }
       }
     }
@@ -971,11 +967,7 @@
private static SharedResult extractSharedOptimizationInfo(ParseContext pctx,
               findAscendantWorkOperators(pctx, optimizerCache, dppSource);
           if (!Collections.disjoint(ascendants, discardedOps)) {
             // Remove branch
-            Operator<?> currentOp = dppSource;
-            while (currentOp.getNumChild() <= 1) {
-              dppBranches.add(currentOp);
-              currentOp = currentOp.getParentOperators().get(0);
-            }
+            removeBranch(dppSource, dppBranches, ops);
           }
         }
       }
@@ -983,6 +975,43 @@
private static SharedResult extractSharedOptimizationInfo(ParseContext pctx,
     return dppBranches;
   }
 
+  /**
+   * Tries to add {@code currentOp} to {@code branchesOps} and, if it was added, repeats this for
+   * each of its parents.
+   *
+   * <p>An operator with more than one child is added only if every one of its children is
+   * already in {@code branchesOps} or {@code discardableOps}; otherwise the walk stops at that
+   * operator. Operators with zero or one child are added unconditionally when reached. Every
+   * parent is followed, not only the first one.
+   *
+   * <p>NOTE(review): there is no visited check, so an operator reachable along several paths
+   * may be re-added and its ancestors re-walked each time. This also assumes the operator graph
+   * is acyclic (otherwise the recursion might not terminate); confirm.
+   *
+   * @param currentOp operator to evaluate; the upward walk starts here
+   * @param branchesOps accumulator of operators to remove; modified in place
+   * @param discardableOps operators treated as already gone when checking the children of a
+   *     multi-child operator; only read, never modified
+   */
+  private static void removeBranch(Operator<?> currentOp, Set<Operator<?>> branchesOps,
+          Set<Operator<?>> discardableOps) {
+    // Keep a multi-child operator if any child is neither removed nor discardable.
+    if (currentOp.getNumChild() > 1) {
+      for (Operator<?> childOp : currentOp.getChildOperators()) {
+        if (!branchesOps.contains(childOp) && !discardableOps.contains(childOp)) {
+          return;
+        }
+      }
+    }
+    branchesOps.add(currentOp);
+    // Stop at an operator with no parent list; otherwise continue into every parent.
+    if (currentOp.getParentOperators() != null) {
+      for (Operator<?> parentOp : currentOp.getParentOperators()) {
+        removeBranch(parentOp, branchesOps, discardableOps);
+      }
+    }
+  }
+
   private static List<Operator<?>> compareAndGatherOps(ParseContext pctx,
           Operator<?> op1, Operator<?> op2) throws SemanticException {
     List<Operator<?>> result = new ArrayList<>();
